{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.io as sio\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize']=(15,5)\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'scipy.sparse.csc.csc_matrix'> (900, 600)\n",
      "<class 'numpy.ndarray'> (900, 1)\n"
     ]
    }
   ],
   "source": [
    "data_file = 'XwindowsDocData.mat'\n",
    "load_data = sio.loadmat(data_file)\n",
    "X_train=load_data['xtrain']\n",
    "print(type(X_train),X_train.shape)\n",
    "y_train=load_data['ytrain']\n",
    "print(type(y_train),y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取总共有多少个类\n",
    "C = np.unique(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.00442478 0.00221239 0.00663717 ... 0.00221239 0.00442478 0.02433628]\n",
      " [1.31028224 1.31028224 1.31028224 ... 1.31207097 1.31207097 1.30549658]]\n",
      "[0.5        1.16884762]\n",
      "[[0.00442478 0.00221239 0.00663717 ... 0.00221239 0.00442478 0.02433628]\n",
      " [0.01106195 0.01106195 0.01106195 ... 0.00442478 0.00442478 0.02876106]]\n",
      "[0.5 0.5]\n"
     ]
    }
   ],
   "source": [
    "pseudoCount = 1  # 伪计数为1\n",
    "N_train,D = X_train.shape\n",
    "# 初始化参数矩阵theta和类先验概率\n",
    "theta = np.empty((len(C),D))\n",
    "c_prior = np.empty(len(C))\n",
    "# 对X_train进行统计\n",
    "for c in C:\n",
    "    XX = X_train[y_train.flatten()==c]   # 获取属于c的所有样本\n",
    "    Non = np.sum(XX,0)           # 统计每个特征为1的数量\n",
    "    Noff = XX.shape[0]-Non\n",
    "    theta[c-1,:]=(Non+pseudoCount)/(Non+Noff+2*pseudoCount)\n",
    "    c_prior[c-1]= XX.shape[0]/N_train\n",
    "    # print(Non.shape,Noff.shape)\n",
    "    # print(Non.shape)\n",
    "    # print(XX.shape)\n",
    "    # print(c_pro)\n",
    "    print(theta)\n",
    "    print(c_prior)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5,1,'p(xj=1|y=2)')"
      ]
     },
     "execution_count": 141,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAE/CAYAAADVKysfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X20bGddH/DvjxsCLcFEyVUpCQZqLGRRebuNWBCvL9SAXckfpZosUbRoulYBdclqV1AXtVRbX9YStUZrxBe0RUTqyy0rNVokaGzBXOCCJDFyjdHcFTVRk1uviJjw9I+ZSSaTmXPmnDNnzjMzn89as87Ze57Z+3lm9n5++7efvWeqtRYAAAD69JiDrgAAAACzSdoAAAA6JmkDAADomKQNAACgY5I2AACAjknaAAAAOiZpY6NV1eOq6taq+sw5yn5rVb15weu/qKruXOQy51jnL1bVZctcJwCrRXyEvkja2HRXJ/nN1tqfblewtfafWmtfv9MVVNV1VXV7VX2yqr52N5Xc4fqeXFXHquruqmpVddFEke9O8l37XQ8AVto6xscvr6qbqur+qvrTqvrxqnriWBHxkW5J2th0/zrJz+7zOj6U5N8k+cA+r2fkk0l+Ncm/mPZka+13knxKVR1ZUn0AWD3rGB/PTfKdSf5BkmcmuSDJ942eFB/pmaSNtVdVd1bV64eXedxXVT9VVY+vqqcm+YdJ3jcsd3ZVnaiq1w6nD1XVb1fVG4bT31FV/22n62+tXdtae1eSj29Tz8+sqo9V1ZPG5j2/qu6tqsfuYH1/1lr7kSQ3b1HsxiRfPu8yAVg/Gxgf39pa+9XW2sdaa/cl+fEkL5wodmPERzokaWNTfFWSL8sgCH1Okm9P8o+T3NFaeyBJWmufSPKKJG+sqmcmuSbJocy4VGJ4ecWsxzU7reDwEpQbk3zF2OxXJHlba+3vqupF26zzRTtY3W1Jnr3TOgKwdjY5Pr44yS0T88RHunTWQVcAluSHW2t3JUlVfVeS/5JBx/xX44Vaax+pqu9M8ktJPiPJpa21B6ctsLV23j7U8y1JvjHJj1bVoSRXJbl8uL6bkixqnX+1wGUBsLo2Mj5W1UuSvDLJ5008JT7SJSNtbIq7xv7/owyuZ78vyROnlH1LkouSXN9a++j+V+0RfiXJJVX19CQvSXJ6eI39oj0xyf37sFwAVsvGxceqekGStyZ5eWvt9yeeFh/pkqSNTXHh2P9PTXJ3kg8neXpVTY44/0iSdyb5sq0uOayqM1s8vnU3lWytfTzJ2zO4XOWrM3YTeFV9wTbr/IIdrOqZGdwADsBm26j4WFXPTXIsyb8a3k83SXykSy6PZFO8uqremeRjSb41yc+31k5V1UeTXJrk/yRJVX11kudncD375UneUlXPbq2dmVxga+2ceVZcVWdncIKkkjy2qh6f5BOttU/OeMnPDB+fnuTbxtb3W0nmXefjM7jfIEkeV1WPHwa8kS/M4H4AADbbxsTHqnpWBt+u/NrW2v+cUUx8pEtG2tgUb03ya0nuGD6+czj/xzI4Y5fht2X9QJKvaa2daa29NcnxJG/a47p/LcnfJPmnSa4b/v/iWYVba7+dwdf2f6C1ducu1/k3SUaB9PeG00mSqvonSf56ny67BGC1bFJ8fF2Sw0l+YmwU7qEvIhEf6ZmRNjbFza21/zxl/puTfLCqntxa++MkTxp/srX2lWOTj0nyiZ2uuLV2dKevyeAeg7fu4nWjddYWT1+TsTOUAGy0jYmPrbWvS/J1WxQRH+mWpI2N1lr72ySXbFeuqmpYbt+vcx+e6Xtekiv2Y/mttak/ug0AI+Ij9GXbyyOr6ier6p6q+siM56uqfqiqTlbVh6vqeYuvJhy4DyS5IIMf4lyk+zO45CRJUlVvSfK/k3xza+2vZr4K6IIYCeIjLEO11rYuUPXiDO6N+ZnW2rOmPP+yJK9N8rIMfuviB1trk795AQBrR4wEYBm2HWlrrf1mkr/cosgVGQSr1lp7b5LzqurJi6ogAPRKjARgGRbx7ZFPySN/mPHUcB4AbDoxEoA9W8QXkUz7lrqp11xW1dVJrk6SJzzhCc9/xjOesYDVA9C797///X/eWjt80PU4AGIkADPNGx8XkbSdSnLh2PQFSe6eVrC1dl0Gv8ORI0eOtOPHjy9g9Zvl6NHB3xtvPMhaAOxMVf3RQdfhgIiRSyI+Aqto3vi4iMsjjyX5muE3ZL0gyenW2p8sYLkAsOrESAD2bNuRtqr6uSRHk5xfVaeS/Pskj02S1tp/TXJ9Bt+KdTLJx7L1jxYCwNoQIwFYhm2TttbaVds835K8emE1AoAVIUYCsAyLuDwSAACAfSJpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAADomKQNAACgY5I2AACAjknaAAAAOiZpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAADomKQNAACgY5I2AACAjknaAAAAOiZpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAADomKQNAACgY5I2AACAjknaAAAAOiZpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAADo2FxJW1VdVlW3V9XJqrpmyvNPrap3V9UHq+rDVfWyxVcVAPoiPgKwDNsmbVV1KMm1SV6a5JIkV1XVJRPFvj3J21trz01yZZIfWXRFAaAn4iMAyzLPSNulSU621u5orX0iyduSXDFRpiX5lOH/5ya5e3FVBIAuiY8ALMU8SdtTktw1Nn1qOG/cdyR5RVWdSnJ9ktdOW1BVXV1Vx6vq+L333ruL6gJANxYWHxMxEoDZ5knaasq8NjF9VZKfbq1dkORlSX62qh617Nbada21I621I4cPH955bQGgHwuLj4kYCcBs8yRtp5JcODZ9QR59ecerkrw9SVpr/zfJ45Ocv4gKAkCnxEcAlmKepO3mJBdX1dOq6uwMbqQ+NlHmj5N8SZJU1TMzCEqu7QBgnYmPACzFtklba+2BJK9JckOS2zL4FqxbquqNVXX5sNjrknxDVX0oyc8l+drW2uQlIgCwNsRHAJblrHkKtdauz+AG6vF5bxj7/9YkL1xs1QCgb+IjAMsw149rAwAAcDAkbQAAAB2TtAEAAHRM0gYAANAxSRsAAEDHJG0AAAAdk7QBAAB0TNIGAADQMUkbAABAxyRtAAAAHZO0AQAAdEzSBgAA0DFJGwAAQMckbQAAAB2TtAEAAHRM0gYAANAxSRsAAEDHJG0AAAAdk7QBAAB0TNIGAADQMUkbAABAxyRtAAAAHZO0AQAAdEzSBgAA0DFJGwAAQMckbQAAAB2TtAEAAHRM0gYAANAxSRsAAEDHJG0AAAAdk7QBAAB0TNIGAADQMUkbAABAxyRtAAAAHZO0AQAAdEzSBgAA0DFJGwAAQMckbQAAAB2TtAEAwA4dPTp4wDJI2gAAADomaQMAAOiYpI2V5tIEAADWnaQNAACgY5I2AACAjknaAAAAOjZX0lZVl1XV7VV1sqqumVHmK6rq1qq6pareuthqAkB/xEcAluGs7QpU1aEk1yZ5SZJTSW6uqmOttVvHylyc5PVJXthau6+qPn2/KgwAPRAfAViWeUbaLk1ysrV2R2vtE0neluSKiTLfkOTa1tp9SdJau2ex1QSA7oiPACzFPEnbU5LcNTZ9ajhv3Ock+Zyq+u2qem9VXbaoCgJAp8RHAJZi28sjk9SUeW3Kci5OcjTJBUl+q6qe1Vq7/xELqro6ydVJ8tSnPnXHlQWAjiwsPiZiJACzzTPSdirJhWPTFyS5e0qZX2mt/V1r7Q+T3J5BkHqE1tp1rbUjrbUjhw8f3m2dAaAHC4uPiRgJwGzzJG03J7m4qp5WVWcnuTLJsYkyv5zki5Kkqs7P4HKQOxZZUQDojPgIwFJsm7S11h5I8pokNyS5LcnbW2u3VNUbq+ryYbEbkvxFVd2a5N1J/m1r7S/2q9IAcNDERwCWZZ572tJauz7J9RPz3jD2f0vyLcMHAGwE8RGAZZjrx7UBAAA4GJI2AACAjknaAAAAOiZpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAFhDR48OHsDqk7QBK8nBCAA8mvi4niRtAAAAHZO0AQAAdEzSBgAA0DFJGwAAQMckbQAA7JkvwID9I2kDAADomKQNAAA2lBHS1SBpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAADomKQNAACgY5I22KGjRwcPAABYBkkbAABAxyRtAAAAHZO0AQAAdEzSBgAA0DFJGwAAQMckbQAAAB2TtAEAAHRM0gYAANAxSRsAAEDHJG0AAAAdk7QBG+Ho0cEDAGDVSNoAAAA6JmkDAADomKQNAACgY3MlbVV1WVXdXlUnq+qaLcq9vKpaVR1ZXBUBoE/iIwDLsG3SVlWHklyb5KVJLklyVVVdMqXcE5N8Y5L3LbqSANAb8RGAZZlnpO3SJCdba3e01j6R5G1JrphS7j8m+d4kH19g/QCgV+IjAEsxT9L2lCR3jU2fGs57SFU9N8mFrbV3LrBuANAz8RGApZgnaasp89pDT1Y9Jsmbkrxu2wVVXV1Vx6vq+L333jt/LQGgPwuLj8PyYiQAU82TtJ1KcuHY9AVJ7h6bfmKSZyW5saruTPKCJMem3WzdWruutXaktXbk8OHDu681ABy8hcXHRIwEYLZ5krabk1xcVU+rqrOTXJnk2OjJ1trp1tr5rbWLWmsXJXlvkstba8f3pcYA0AfxEYCl2DZpa609kOQ1SW5IcluSt7fWbqmqN1bV5ftdQQDokfgIwLKcNU+h1tr1Sa6fmPeGGWWP7r1aANA/8RGAZZjrx7UBAAA4GJI2AACAjknaAAAAOiZpAwAA6JikDQAAoGOSNgAAgI5J2gAAADomaQMAAOiYpA0AAKBjkjYAAICOSdoAAAA6JmkDAKB7R48OHrCJJG0AAAAdk7QBAAB0TNIGAADQMUkbAABAxyRtAAAAHZO0AQAAdEzSBgAA0DFJGwAAQMckbQAAAB2TtAEsyNGjgwcAwCJJ2gAAADomaQMAAJbK1Sk7I2kDAGAhTpxwIA77QdIGAADQMUkbAABAxyRt0BnXeK8PnyUAsAiSNgAAYGOtwklWSRsAAEDHJG2wC74dCwCAZZG0AQAAdEzSxkpZhWuOAQBgkSRtAAAAHZO0AQCsEFedwOaRtAEAAHRM0gYAsAaMwMH6krQBAAB0TNIGACyUER+AxZK0AQAHSpIHsDVJGwAAQMckbQAAG85oJ/RN0gYAANAxSRsAAEDHJG0AAAAdk7QBAAB0bK6kraouq6rbq+pkVV0z5flvqapbq+rDVfWuqvqsxVcVAPoiPgKwDNsmbVV1KMm1SV6a5JIkV1XVJRPFPpjkSGvtc5O8I8n3LrqiANAT8RGAZZlnpO3SJCdba3e01j6R5G1Jrhgv0Fp7d2vtY8PJ9ya5YLHVBGAvfJ33vhAfAViKeZK2pyS5a2z61HDeLK9K8r+mPVFVV1fV8ao6fu+9985fSwDoz8LiYyJGAjDbPElbTZnXphasekWSI0m+b9rzrbXrWmtHWmtHDh8+PH8tAaA/C4uPiRgJMIurRZKz5ihzKsmFY9MXJLl7slBVfWmSb0vyha21v11M9QCgW+IjAEsxz0jbzUkurqqnVdXZSa5Mcmy8QFU9N8mPJbm8tXbP4qsJAN0RHwFYim2TttbaA0lek+SGJLcleXtr7ZaqemNVXT4s9n1JzknyC1V1oqqOzVgcHCjD68CiiI8ALMs8l0emtXZ9kusn5r1h7P8vXXC9ujc68L/xxoOsBT2wLcDmEh8BWIa5flwbgPmcOGE0FwBYLEkbAABAxyRtAAAAHZO0AQCsmaNHB5drA+thY5M23yIIAACsgo1N2gAAAFaBpA0AAKBjkjYAAICOSdoAAAA6JmkDiC8nAgD6JWmDDSEpAQBYTZI2AACAjknaAAAAJvR0lZKkDQAAVlBPScVurUMblkHSBgAA0DFJG7AlZ8D2znsIB89+CKyysw66ArAXJ04cdA0AAGB/GWkDAGCjzTsSa8SWgyJpAx5BQNosPm8AdkP8WC5JG0AHBD8AduPECfFjE0jaFsyBFwAAPJJj5L2RtAEAe+J+IID9JWkDAJZK8gawM5I2AAA4AE5gMC9JG3CgDipgCZTAKtBXAckG/rj2Ond8o7bdeONB1mJ3Vrnum8jnBcCqELNYB0ba9pGzY/TKtgkAsDo2bqQNdsNZOoBH28nJn6NHB78n9Zzn7FdtYDEWFfMdO7BIRtqAjbXIHyQdHZDCquppBL6HuvRQB4ARI20A+2iUyBld4KCtwln/VUmSVuG9ZH8ZOd45Jzb3xkgbsCfrfjZ6v9u37u8fzLLIkW6AdWekbU7OqgEAk5Z5fLCOSa7jK7ayjtv8bq3dSJuz1sBe6UcA2HSrHAtXue6zGGnbxn584M4qsR82Zbtat04Y2J1N6fNgXc1zz7f7wh8maWMqwXBgkxKETWorAH1zHLK392B0z+gmv3/rZu0uj1xV6ziMCwDAwerxGPPo0eTMmYOuxWox0raPfLXpYq3S1+v2dIawt476oNgf4dF66qu2s0p17cWZM/o+WBcbkbQto6N3YMyirdMBylZtWdcDip382PY81+yv0/ZAP3rc/w76HpbRvtbje7NMq97nHGT9D/qY8MQJo1jraCOSNmDnDjro9GQv9wbsV+Bc1AHJph+Ysv/0JcB+WfWTCzshaduhnQafM2cefbA3OkjayYa2jG+xXIcNfz/aMOusrx+G3d5e35912Cb3apHbmBvT6cFORqB3s71uWr8xb3v3430Zv21h1nKX/XmsUlw+qNs+Nm0fWRdr+UUk4wfTO7lEqRfjN4zu5ObREycW29b9TErmbdcibp49ejQ577xHL6fHG3OnWZV67pdVS45H95Ds5+Upq9ivsVoW3e9sFZ92sz3Pes2i4+Aq6aVfmOc+uk2Pazux1xg473axarF2E230SNtuzuId9LX2i7DfO+WqnsGZt96jg/Fzzln8sndiVd/nVbXoUSrBEQ5Ob/1nb/VZFbt93xb1fq/aF72IO4u17P12LUfaJi1qp9ruzPn4Gb7tvsr09Onkppse/fpF7FCrfgZrcmRsu/dldAnqqtnPz2ly2b2cfe512+zlDPXIrNHhg6jHQdeBxdpLnJlne1jnbWavbdvu9as40jE6LppV98m+dfI9mLfv3eq9W8RI1Kq973u1H8cEm/g+LttGj7T1ZtTxbTeKt6gz/vPusIs4k+As4vamdXb7MbJ7UNfQJzvf5qbNn1zGeNnx7Wu/trm9LHe3r93JdmBfY5reTkzsxdGjjzzpuV1/Met+q8lYuox7x+ep1zJs10+M3uOdXFEyy6wT57P68nnfk51s01uVW+YVVFttY7uJzbNi4k4vzd9qUGIRx5zTbk+56abkwQd3v8y91GPascIq9I8bMdKWPDwas4izC5Md0EEHw17PbozqddDvz1b248zmTtu708+v58973tGhvbbhzJlHB5hFf5bjy5u2vl4t6gxqz/stO7PdZ9n71QqL6PNGB4mL2qb34x7y3fQxq7yfzupXe41x43bz+a/yZ7WMUeDJ76MYX98iL0PdbVt62C7nGmmrqsuS/GCSQ0ne3Fr77onnH5fkZ5I8P8lfJPnK1tqdi63q9na6Qyz6WuStOtxFn/0+cWJwgDyy32eI5jnTstWlC+PvzXnnDaZf9KKd12H8Mox5v3Vzu894/L3czfs4fkns6P9Z29ZWZ9Imy4/K7GQ7neyM5t0ntvpMeggyo3bNc2Az7Yzebu5DnFWP8eXudn+e3A7mOVjbaptK5q/LvPtyD5/7KliV+DjuppsG+/z9929dbqdn2Mf7v2T+Pn58v57cV+cZBZksM+9VK/MYT2h3sszxfWi714yPjCz63tmd2uqAdhQnt9tu9rLuZPB+zXrvRtvHPAfPk33YtG/znve1I9v1jTv53PfDrPrN2if2KwnZSVyaFtu2e4/nWfdWdntMO2/fuV+2Tdqq6lCSa5O8JMmpJDdX1bHW2q1jxV6V5L7W2mdX1ZVJvifJV+5HhWeZNUS81cY7WXanpnVuDz746B1jdBnHOedsndjMGiae7MhOn04OHZp9ADqeRIxPT9vIZnVi0w6OZ23k0zqp8TaPL2erYfvJRHRWPc6ceXjHmVbHySC7G5NneKYZ1WN83bOWNW+wH63rnHO2v1wheeQytzrwGS17twcyZ84M3u/R5TzjbZ03QG2XOE128pPtn3x/H3xw/jPTZ848cv+a7Bt2epZ78h6OWR3/aJ2HDs2/7Fnr28llJOOfV7L1CYTxz0SytjM9x8ftTo5td+A7eZJsNC/Z/YHeeP+21f422rcXcWlesvWxwSjuTCYKo/ZPi+njcWz8vZiV1EyLbePHIaP4f+jQw3FlFGu3ukxx2oHu6dOzk5VRmcn9fLv9fjzWjh/HzLrkdNRXzVruIvqY8T57vP+aPPYYf99Gn+Xo9adPP/pzSQbzRsdZsw7sJ7ffaccMo7qM9sHROkef505i8XZJxpkzyXve83CsGX0G02L+eL8/K/aNby+nTz/8//hzsy75neeYZ9ox73h/M23/mrWcUVvGT75vddw9eUw6K5kd307Hn99J8r9o84y0XZrkZGvtjiSpqrcluSLJeFC6Isl3DP9/R5IfrqpqrbUF1nWqaZ3CqNMb7bSjznA74530oUOPDHazNuy9dj6zzmqNkpLdnqmZdoA+6wzQKKCMNvxpidVWZwHnHcUYT3BGneciDhLHl7VVgjV6P8bXO+qYx03bicc7+9F08ujXbvdFNaMyO+20p3US243AbXXgM/7ayfdtu89jVH6rkwbjB0GLSh5HgWO7fXm8/tPqOTlit9OEbStbJVfTDgjHTR5QTPt8Zx1oj/cjo+WOv8+TB8uT/SO71nV83M7onpJZ+9SDDz6cBGw18j9tP3vwwYfj2G7M6r/mjRfjJ9UWkfxN9hnj79v4ge25584+LhnVZXy/mxa3Zr13k8udTFq2q/+0tpw1cRQ4PsK5VV822RdNnsQcbSejtp4+PVjXdnUd377GE9jtbJd0jj6fM2cG9Zhs2+nTg6Tn3HMfvdzxBHqrgYBR3z65T806zttuWdPKjx8XTh5T7dSsuLfViYLx46Px5Ha0XW+VAO3kypJJo/5g1uc8vszxfm2e463JbXe0T4zvq+P1nzweXKbaLm5U1cuTXNZa+/rh9Fcn+bzW2mvGynxkWObUcPoPhmX+fNZyjxw50o4fP77rio8feE0eQM8zvV9lD+q1q1DHdW+fOh78a9Vx8XU899zFXApSVe9vrR3Z+5L6sV/xMdl7jBwfHVnHbXQvr1VHdVyX16rjwdYxGQzw7HXUbd74OM9IW02ZN5npzVMmVXV1kquHk2eq6vY51r+V85M8FPhGb+Y80/tVdp9ee36SP++8jot47czPs6M6LqLsoz7PDuu419ee/+CDu9s397LeA9hOtt03O6jjnsqePp1UPXLf3KXP2uPre7Sw+Jjsb4zscTtb4Gsf0d90Wsd9i48d1XERr13o8U7Hn/VCjnc6bt/IjvbNA6rjntbznvfk/KrlxMd5krZTSS4cm74gyd0zypyqqrOSnJvkLycX1Fq7Lsl181RsHlV1fN3O3E6jnetlE9q5CW1MtJPFxcdEjNytTWjnJrQx0c51swntXGYbHzNHmZuTXFxVT6uqs5NcmeTYRJljSV45/P/lSX6jh+v1AWAfiY8ALMW2I22ttQeq6jVJbkhyKMlPttZuqao3JjneWjuW5CeS/GxVnczgDOKV+1lpADho4iMAyzLX77S11q5Pcv3EvDeM/f/xJP9ysVWby8IuI+mcdq6XTWjnJrQx0c6N13F8TDbnc9uEdm5CGxPtXDeb0M6ltXHbb48EAADg4MxzTxsAAAAHZGWTtqq6rKpur6qTVXXNQddnL6rqJ6vqnuHv+YzmfVpV/XpVfXT491OH86uqfmjY7g9X1fMOrubzq6oLq+rdVXVbVd1SVd80nL9u7Xx8Vf1OVX1o2M7/MJz/tKp637CdPz/80oJU1eOG0yeHz190kPXfiao6VFUfrKp3DqfXsY13VtXvVtWJqjo+nLdW22ySVNV5VfWOqvq94T76+evYzk0hPq7etrkJMXKT4mMiRq7DNjvSS4xcyaStqg4luTbJS5NckuSqqrrkYGu1Jz+d5LKJedckeVdr7eIk7xpOJ4M2Xzx8XJ3kR5dUx716IMnrWmvPTPKCJK8efmbr1s6/TfLFrbVnJ3lOksuq6gVJvifJm4btvC/Jq4blX5XkvtbaZyd507DcqvimJLeNTa9jG5Pki1przxn7St9122aT5AeT/Gpr7RlJnp3B57qO7Vx74uPKbpubECM3KT4mYuQ6bLMjfcTI1trKPZJ8fpIbxqZfn+T1B12vPbbpoiQfGZu+PcmTh/8/Ocntw/9/LMlV08qt0iPJryR5yTq3M8nfT/KBJJ+XwY9onjWc/9D2m8G3zn3+8P+zhuXqoOs+R9suyKCT+uIk78zgB4TXqo3D+t6Z5PyJeWu1zSb5lCR/OPmZrFs7N+UhPq7HtrnuMXKd4+OwvmJkW49ttqcYuZIjbUmekuSuselTw3nr5DNaa3+SJMO/nz6cv/JtHw79PzfJ+7KG7RxeEnEiyT1Jfj3JHyS5v7X2wLDIeFseaufw+dNJnrTcGu/KDyT5d0k+OZx+UtavjUnSkvxaVb2/qq4ezlu3bfbpSe5N8lPDS3neXFVPyPq1c1Nswuez1tvmOsfIDYmPiRi5NttsOoqRq5q01ZR5m/I1mCvd9qo6J8n/SPLNrbX/t1XRKfNWop2ttQdba8/J4EzbpUmeOa3Y8O/KtbOq/nmSe1pr7x+fPaXoyrZxzAtba8/L4HKHV1fVi7cou6rtPCvJ85JYLIcEAAACXUlEQVT8aGvtuUn+Og9f5jHNqrZzU2zy57PybV/3GLnu8TERI7cou6rt7CZGrmrSdirJhWPTFyS5+4Dqsl/+rKqenCTDv/cM569s26vqsRkEo//eWvvF4ey1a+dIa+3+JDdmcH/CeVU1+l3E8bY81M7h8+dm8AO8PXthksur6s4kb8vg8o8fyHq1MUnSWrt7+PeeJL+UwUHGum2zp5Kcaq29bzj9jgwC1Lq1c1NswuezltvmJsXINY6PiRi5bttsNzFyVZO2m5NcPPwmnrOTXJnk2AHXadGOJXnl8P9XZnB9+2j+1wy/neYFSU6Phmd7VlWV5CeS3NZa+/6xp9atnYer6rzh/38vyZdmcMPqu5O8fFhssp2j9r88yW+04UXQvWqtvb61dkFr7aIM9r3faK19VdaojUlSVU+oqieO/k/yz5J8JGu2zbbW/jTJXVX1j4azviTJrVmzdm4Q8XEFt81NiJGbEB8TMTJrtM0mncXIg7qxb6+PJC9L8vsZXA/9bQddnz225eeS/EmSv8sgQ39VBtczvyvJR4d/P21YtjL4ZrA/SPK7SY4cdP3nbOOLMhge/nCSE8PHy9awnZ+b5IPDdn4kyRuG85+e5HeSnEzyC0keN5z/+OH0yeHzTz/oNuywvUeTvHMd2zhsz4eGj1tG/cy6bbPDuj8nyfHhdvvLST51Hdu5KQ/xcfW2zU2IkZsWH4dtECNXeJsda2sXMbKGKwAAAKBDq3p5JAAAwEaQtAEAAHRM0gYAANAxSRsAAEDHJG0AAAAdk7QBAAB0TNIGAADQMUkbAABAx/4/Ez6u+Zz3ZOUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1080x360 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "ax1 = plt.subplot(121)\n",
    "ax1.vlines(np.arange(theta.shape[-1]),0,theta[0,:],colors='blue')\n",
    "ax1.set_ylim(0,1)\n",
    "ax1.set_title('p(xj=1|y=1)')\n",
    "ax2 = plt.subplot(122)\n",
    "ax2.vlines(np.arange(theta.shape[-1]),0,theta[1,:],colors='blue')\n",
    "ax2.set_ylim(0,1)\n",
    "ax2.set_title('p(xj=1|y=2)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(900, 600)\n"
     ]
    }
   ],
   "source": [
    "X_test = load_data['xtest'].toarray()\n",
    "print(X_test.shape)\n",
    "y_test = load_data['ytest']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2, 600)\n"
     ]
    }
   ],
   "source": [
    "Ntest = X_test.shape[0]\n",
    "logPrior = np.log(c_prior)\n",
    "logPost = np.empty((len(C),Ntest))\n",
    "logTrue = np.log(theta)\n",
    "logTrueNot = np.log(1-theta)\n",
    "XtestNot = 1-X_test\n",
    "XtrainNot = 1-X_train.toarray()\n",
    "print(logTrue.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(900,)\n",
      "(900,)\n",
      "(900,)\n",
      "(900,)\n",
      "[[-41.18813634 -23.01671991 -71.89830658 ... -43.85026596 -26.02082748\n",
      "  -39.14087707]\n",
      " [-46.14089735 -27.50579873 -79.7469371  ... -38.79526696 -22.47443335\n",
      "  -40.55321972]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.18666666666666665"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for c in C:\n",
    "    L_true = np.dot(logTrue[c-1,:],X_test.T)\n",
    "    print(L_true.shape)\n",
    "    L_false = np.dot(logTrueNot[c-1,:],XtestNot.T)\n",
    "    print(L_false.shape)\n",
    "    logPost[c-1,:]=L_true+L_false+logPrior[c-1]\n",
    "print(logPost)\n",
    "y_hat = np.argmax(logPost,0)+1\n",
    "1-y_hat[y_hat==y_test.flatten()].size/y_hat.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(600,)\n"
     ]
    }
   ],
   "source": [
    "# 求取MI\n",
    "theta_j = np.dot(c_prior[np.newaxis,:],theta)\n",
    "I=np.empty((len(C),D))\n",
    "for c in C:\n",
    "    I[c-1,:]=theta[c-1,:]*c_prior[c-1]*np.log(theta[c-1,:]/theta_j)+ \\\n",
    "    (1-theta[c-1,:])*c_prior[c-1]*np.log((1-theta[c-1,:])/(1-theta_j))\n",
    "MI = np.sum(I,0)\n",
    "print(MI.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "numpy.ndarray"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = load_data['vocab']\n",
    "type(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class1</th>\n",
       "      <th>prob</th>\n",
       "      <th>class2</th>\n",
       "      <th>prob</th>\n",
       "      <th>highest MI</th>\n",
       "      <th>MI</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>subject</td>\n",
       "      <td>0.997788</td>\n",
       "      <td>subject</td>\n",
       "      <td>0.997788</td>\n",
       "      <td>windows</td>\n",
       "      <td>0.149046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>this</td>\n",
       "      <td>0.639381</td>\n",
       "      <td>windows</td>\n",
       "      <td>0.639381</td>\n",
       "      <td>microsoft</td>\n",
       "      <td>0.066171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>with</td>\n",
       "      <td>0.539823</td>\n",
       "      <td>this</td>\n",
       "      <td>0.539823</td>\n",
       "      <td>dos</td>\n",
       "      <td>0.063858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>but</td>\n",
       "      <td>0.537611</td>\n",
       "      <td>with</td>\n",
       "      <td>0.537611</td>\n",
       "      <td>motif</td>\n",
       "      <td>0.054201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>you</td>\n",
       "      <td>0.517699</td>\n",
       "      <td>but</td>\n",
       "      <td>0.517699</td>\n",
       "      <td>window</td>\n",
       "      <td>0.046677</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    class1      prob   class2      prob highest MI        MI\n",
       "0  subject  0.997788  subject  0.997788    windows  0.149046\n",
       "1     this  0.639381  windows  0.639381  microsoft  0.066171\n",
       "2     with  0.539823     this  0.539823        dos  0.063858\n",
       "3      but  0.537611     with  0.537611      motif  0.054201\n",
       "4      you  0.517699      but  0.517699     window  0.046677"
      ]
     },
     "execution_count": 139,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cols=['class1','prob','class2','prob','highest MI','MI']\n",
    "first_col = [element[0] for element in vocab[np.argsort(-theta[0,:])[0:5]].flatten()]\n",
    "second_col = -np.sort(-theta[0,:])[0:5]\n",
    "third_col = [element[0] for element in vocab[np.argsort(-theta[1,:])[0:5]].flatten()]\n",
    "forth_col = -np.sort(-theta[1,:])[0:5]\n",
    "fivth_col = [element[0] for element in vocab[np.argsort(-MI)[0:5]].flatten()]\n",
    "sixth_col = -np.sort(-MI)[0:5]\n",
    "df = pd.DataFrame({'class1':first_col,'prob':second_col,'class2':third_col,'prob':forth_col,\n",
    "                   'highest MI':fivth_col,'MI':sixth_col}\n",
    ")\n",
    "df.loc[:,cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
